# This code is adapted from libauc
# The data can be downloaded from http://download.cs.stanford.edu/deep/CheXpert-v1.0-small.zip
import numpy as np
import torch 
from torch.utils.data import Dataset
import torchvision.transforms as tfs
import cv2
from typing import Optional, Callable, Tuple, Any
from PIL import Image
import pandas as pd
import os.path 

class CheXpert(Dataset):
    '''
    CheXpert chest X-ray dataset with optional subgroup splits and test-time repetition.

    Reference: 
        Adapted from:
        Large-scale Robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification
        Zhuoning Yuan, Yan Yan, Milan Sonka, Tianbao Yang
        International Conference on Computer Vision (ICCV 2021)

    Args:
        image_root_path: directory containing ``{mode}.csv`` and the images; the CSV
            ``Path`` entries (with any ``CheXpert-v1.0-small/`` / ``CheXpert-v1.0/``
            prefix stripped) are joined onto it. Must be non-empty.
        image_size: side length images are resized to when ``transform`` is None.
        class_index: index into the label columns of the single label to return,
            or -1 for multi-label mode (all of ``train_cols`` plus 'age_bin' and
            'gender_bin').
        split_by_class_index, split_by_class_value: subgroup filter applied after
            label imputation. An index < 5 keeps rows where
            ``train_cols[index] == value``; 5 keeps Age >= 65 (value 0) or
            Age <= 55 (otherwise); 6 keeps Sex == 'Male' (value 0) or 'Female'
            (otherwise); any other index keeps every row.
        test_time: if True, each image is exposed ``steps_per_example`` times and
            every item stacks ``batch_size`` copies of the image.
        batch_size, start_index, steps_per_example: test-time repetition settings.
        use_frontal: keep only rows whose 'Frontal/Lateral' is 'Frontal'.
        use_upsampling: append an extra copy of the rows that are positive (== 1)
            for each column in ``upsampling_cols``.
        flip_label: in single-label mode, replace every 0 in the frame with -1.
        shuffle, seed: shuffle the rows once after seeding the global numpy RNG.
        verbose: print class statistics.
        upsampling_cols: list of columns to upsample; defaults to
            ['Cardiomegaly', 'Consolidation'].
        train_cols: label columns; defaults to ['Cardiomegaly', 'Edema',
            'Consolidation', 'Atelectasis', 'Pleural Effusion']. The caller's list
            is never modified.
        mode: 'train' or 'valid'; selects the CSV and, for 'train', enables
            class-balanced sampling and random affine augmentation.
        transform: optional callable applied to a PIL RGB image; if None a built-in
            resize + normalization is used.
    '''

    _DEFAULT_UPSAMPLING_COLS = ('Cardiomegaly', 'Consolidation')
    _DEFAULT_TRAIN_COLS = ('Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis', 'Pleural Effusion')
    # Uncertain (-1) labels are mapped to 1 for these columns ...
    _UNCERTAIN_TO_ONE = frozenset({'Edema', 'Atelectasis'})
    # ... and to 0 for these; any other column only has missing values filled with 0.
    _UNCERTAIN_TO_ZERO = frozenset({
        'Cardiomegaly', 'Consolidation', 'Pleural Effusion',
        'No Finding', 'Enlarged Cardiomediastinum', 'Lung Opacity', 'Lung Lesion',
        'Pneumonia', 'Pneumothorax', 'Pleural Other', 'Fracture', 'Support Devices',
    })
    # Per-channel normalization constants (presumably ImageNet statistics).
    _NORM_MEAN = np.array([[[0.485, 0.456, 0.406]]])
    _NORM_STD = np.array([[[0.229, 0.224, 0.225]]])

    def __init__(self,  
                 image_root_path='',
                 image_size=224,
                 class_index=0, 
                 split_by_class_index=1,
                 split_by_class_value=0,
                 test_time: bool = False,
                 batch_size=1,
                 start_index=0,
                 steps_per_example=1,
                 use_frontal=True,
                 use_upsampling=False,
                 flip_label=False,
                 shuffle=True,
                 seed=123,
                 verbose=True,
                 upsampling_cols=None,
                 train_cols=None,
                 mode='train',
                 transform=None):
        # Test-time repetition settings (just like in the extended folder dataset).
        self.test_time = test_time
        self.batch_size = batch_size
        self.steps_per_example = steps_per_example
        self.start_index = start_index
        self.transform = transform
        assert mode in ['train', 'valid']
        # Validate before touching the filesystem.
        assert image_root_path != '', 'You need to pass the correct location for the dataset!'
        self.mode = mode
        self.class_index = class_index
        self.image_size = image_size

        # Work on private copies: extending a shared default list in place would
        # make every new instance gain two more demographic columns.
        if upsampling_cols is None:
            upsampling_cols = list(self._DEFAULT_UPSAMPLING_COLS)
        if train_cols is None:
            train_cols = self._DEFAULT_TRAIN_COLS
        train_cols = list(train_cols) + ['age_bin', 'gender_bin']

        self.df = self._load_frame(image_root_path, mode, use_frontal)

        # Upsample selected columns (on the raw labels, before imputation).
        if use_upsampling:
            assert isinstance(upsampling_cols, list), 'Input should be list!'
            sampled_df_list = []
            for col in upsampling_cols:
                print('Upsampling %s...' % col)
                sampled_df_list.append(self.df[self.df[col] == 1])
            self.df = pd.concat([self.df] + sampled_df_list, axis=0)

        self._impute_labels(train_cols)
        self._apply_split(train_cols, split_by_class_index, split_by_class_value)
        self._num_images = len(self.df)

        # 0 --> -1; disabled in multi-label mode.
        if flip_label and class_index != -1:
            self.df.replace(0, -1, inplace=True)

        if shuffle:
            data_index = list(range(self._num_images))
            np.random.seed(seed)  # seeds the global RNG, also used by __getitem__
            np.random.shuffle(data_index)
            self.df = self.df.iloc[data_index]

        if class_index == -1:
            if verbose:
                print('Multi-label mode: True, Number of classes: [%d]' % len(train_cols))
                print('-' * 30)
            self.select_cols = train_cols
            self.value_counts_dict = {
                class_key: self.df[select_col].value_counts().to_dict()
                for class_key, select_col in enumerate(train_cols)
            }
        else:
            # select_cols determines the number of classes.
            self.select_cols = [train_cols[class_index]]
            self.value_counts_dict = self.df[self.select_cols[0]].value_counts().to_dict()

        self._images_list = [os.path.join(image_root_path, path) for path in self.df['Path'].tolist()]
        label_values = self.df[train_cols].values
        if class_index != -1:
            self._labels_list = label_values[:, class_index].tolist()
        else:
            self._labels_list = label_values.tolist()
        # Row indices of negative / positive labels for class-balanced sampling.
        # Comparing with 1 (rather than using nonzero) keeps -1 labels out of the
        # positive pool when flip_label is enabled.
        label_array = np.array(self._labels_list)
        self._index_0 = (label_array != 1).nonzero()[0]
        self._index_1 = (label_array == 1).nonzero()[0]

        if class_index != -1:
            self._report_binary(flip_label, class_index, verbose)
        else:
            self._report_multilabel(train_cols, verbose)

    @staticmethod
    def _load_frame(image_root_path, mode, use_frontal):
        """Read ``{mode}.csv``, normalise paths, and add the demographic bin columns."""
        csv_path = os.path.join(image_root_path, f'{mode}.csv')
        df = pd.read_csv(csv_path)
        # Make paths relative to image_root_path (literal, not regex, replacement).
        df['Path'] = (df['Path']
                      .str.replace('CheXpert-v1.0-small/', '', regex=False)
                      .str.replace('CheXpert-v1.0/', '', regex=False))
        if use_frontal:
            df = df[df['Frontal/Lateral'] == 'Frontal'].copy()
        df['Age'] = df['Age'].fillna(0)
        df['age_bin'] = df['Age'] <= 55
        df['gender_bin'] = df['Sex'] == 'Female'
        return df

    def _impute_labels(self, train_cols):
        """Map uncertain (-1) labels per column policy and fill missing values with 0."""
        for col in train_cols:
            column = self.df[col]
            if col in self._UNCERTAIN_TO_ONE:
                column = column.replace(-1, 1)
            elif col in self._UNCERTAIN_TO_ZERO:
                column = column.replace(-1, 0)
            self.df[col] = column.fillna(0)

    def _apply_split(self, train_cols, split_by_class_index, split_by_class_value):
        """Restrict ``self.df`` to the requested subgroup (see the class docstring)."""
        if split_by_class_index < 5:
            keep = self.df[train_cols[split_by_class_index]] == split_by_class_value
        elif split_by_class_index == 5:
            # Split by age; missing ages were filled with 0 when loading.
            ages = self.df['Age']
            keep = (ages >= 65) if split_by_class_value == 0 else (ages <= 55)
        elif split_by_class_index == 6:
            # Split by sex.
            keep = self.df['Sex'] == ('Male' if split_by_class_value == 0 else 'Female')
        else:
            return
        self.df = self.df[keep].copy()

    def _report_binary(self, flip_label, class_index, verbose):
        """Compute (and optionally print) the positive ratio in single-label mode."""
        negative_key = -1 if flip_label else 0
        n_pos = self.value_counts_dict.get(1, 0)
        n_neg = self.value_counts_dict.get(negative_key, 0)
        total = n_pos + n_neg
        self.imratio = n_pos / total if total else 0.0
        if verbose:
            print('-' * 30)
            print('Found %s images in total, %s positive images, %s negative images' % (self._num_images, n_pos, n_neg))
            print('%s(C%s): imbalance ratio is %.4f' % (self.select_cols[0], class_index, self.imratio))
            print('-' * 30)

    def _report_multilabel(self, train_cols, verbose):
        """Compute (and optionally print) per-column positive ratios in multi-label mode."""
        imratio_list = []
        for class_key, select_col in enumerate(train_cols):
            counts = self.value_counts_dict[class_key]
            # A column may lack one class entirely; record it as a zero count.
            counts.setdefault(0, 0)
            counts.setdefault(1, 0)
            total = counts[0] + counts[1]
            imratio = counts[1] / total if total else 0.0
            imratio_list.append(imratio)
            if verbose:
                print('Found %s images in total, %s positive images, %s negative images' % (self._num_images, counts[1], counts[0]))
                print('%s(C%s): imbalance ratio is %.4f' % (select_col, class_key, imratio))
                print()
        self.imratio = np.mean(imratio_list)
        self.imratio_list = imratio_list

    @property
    def class_counts(self):
        """Label value counts: a dict for single-label mode, a dict of dicts keyed by column index otherwise."""
        return self.value_counts_dict

    @property
    def imbalance_ratio(self):
        """Positive ratio (mean over columns in multi-label mode)."""
        return self.imratio

    @property
    def num_classes(self):
        """Number of label columns returned per sample."""
        return len(self.select_cols)

    @property
    def data_size(self):
        """Number of rows kept after filtering, splitting and upsampling."""
        return self._num_images

    def image_augmentation(self, image):
        """Apply a random affine jitter (rotation, shift, scale) to a PIL image."""
        # `fill` replaces the older `fillcolor` keyword of RandomAffine.
        img_aug = tfs.Compose([tfs.RandomAffine(degrees=(-15, 15), translate=(0.05, 0.05), scale=(0.95, 1.05), fill=128)])
        return img_aug(image)

    def __len__(self):
        """Number of items; in test-time mode scaled by steps_per_example * batch_size."""
        if not self.test_time:
            return self._num_images
        # NOTE(review): __getitem__ maps idx -> idx // steps_per_example + start_index,
        # so with batch_size > 1 (or start_index > 0) the top indices overrun
        # _images_list unless train-mode resampling replaces them — confirm intent.
        return self.steps_per_example * self.batch_size * self._num_images

    def __getitem__(self, idx):
        """Load one sample and its label.

        Returns ``(samples, label)``. Without a custom transform, ``samples`` is a
        float tensor of shape (3, image_size, image_size); in test-time mode
        ``batch_size`` copies are stacked along a new leading dimension. ``label``
        is an int in single-label mode, or an int64 numpy vector in multi-label
        mode (class_index == -1).

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        if self.test_time:
            idx = (idx // self.steps_per_example) + self.start_index
        if self.mode == 'train':
            # Class-balanced sampling: the incoming idx is discarded and replaced by
            # a random draw from the negative or positive pool with equal odds.
            # NOTE(review): this also overrides the test-time index when
            # test_time=True and mode='train' — confirm that combo is unused.
            pool = self._index_0 if np.random.random() < 0.5 else self._index_1
            idx = np.random.choice(pool)

        image_path = self._images_list[idx]
        image = cv2.imread(image_path, 0)  # flag 0: load as single-channel grayscale
        if image is None:
            # cv2.imread reports unreadable/missing files by returning None.
            raise FileNotFoundError('Could not read image: %s' % image_path)

        if self.transform is None:
            image = Image.fromarray(image)
            if self.mode == 'train' and not self.test_time:
                image = self.image_augmentation(image)
            image = np.array(image)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        if self.transform is None:
            # Resize, scale to [0, 1], normalise per channel, then HWC -> CHW.
            image = cv2.resize(image, dsize=(self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
            image = image / 255.0
            image = (image - self._NORM_MEAN) / self._NORM_STD
            image = torch.Tensor(image.transpose((2, 0, 1)).astype(np.float32))
        else:
            image = self.transform(Image.fromarray(image))

        if self.test_time:
            samples = torch.stack([image] * self.batch_size, dim=0)
        else:
            samples = image

        if self.class_index != -1:  # single-label mode
            label = int(self._labels_list[idx])
        else:  # multi-label mode
            label = np.array(self._labels_list[idx]).reshape(-1).astype(np.int64)
        return samples, label

